# ------------------------------------------------------------------------------
# Plots cost-effectiveness curves (cost per DALY by predicted-risk percentile)
# against benchmark cost-effectiveness thresholds
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 01_daly_cutoff_analysis.sh {overnight}
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(scales)
library(tidyverse)
library(optparse)
library(glue) # glue
library(ggplot2)

# Project-local helper modules. `here::here` is namespaced explicitly so a
# later-attached package cannot mask it.
u <- modules::use(here::here("lib", "util.R"))
a <- modules::use(here::here("lib", "aesthetics.R"))

# Presumably registers the Optima font so ggplot text can use
# family = "Optima" (see the plotting helpers) -- confirm in aesthetics.R.
a$get_font("Optima", here::here("lib", "optima.ttf"))

# Output folder for the generated PNGs (may be redirected by --overnight).
temp <- here::here("code", "04_cost_effectiveness", "temp")

# Command Line Args ------------------------------------------------------------
# `default = FALSE` matters: without it optparse leaves `overnight` NULL when
# the flag is omitted. `ifelse(NULL, ...)` then returns logical(0), which
# would silently turn `temp` into logical(0).
arg_config <- list(
  make_option(
    "--overnight",
    type = "logical", default = FALSE,
    help = "Write plots to the 'overnight' subfolder of the temp directory"
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Scalar condition: use if/else rather than vectorised ifelse().
if (isTRUE(arg_list$overnight)) {
  temp <- file.path(temp, "overnight")
}

# Helpers ----------------------------------------------------------------------
# Summarise cost and DALYs by predicted-risk tile for tested encounters.
#
# Reads the global `cohort`, keeps rows with `test_010_day` TRUE and, when
# `test_type == "cath"`, only rows whose `first_test` is "cath". Any other
# test_type (e.g. "stress") currently uses all tested rows.
# NOTE(review): confirm that is intended.
#
# Args:
#   test_type: "cath" restricts to catheterization; anything else keeps all.
#   ...: extra config columns passed by pmap(); ignored.
#
# Returns:
#   One row per `tile_stent_or_cabg_005_tested` with columns n, avg_cost,
#   avg_daly_250 and cost_per_daly_{100,200,250,300,500}. Each cost_per_daly_*
#   is a pooled ratio, sum(cost) / sum(daly_*), and is not capped.
get_avg_cost <- function(test_type = "all", ...) {
  # Pooled ratio of two columns, dropping missing values in each sum.
  pooled_ratio <- function(num, den) {
    sum(num, na.rm = TRUE) / sum(den, na.rm = TRUE)
  }

  tested <- copy(cohort) %>% filter(test_010_day)
  if (test_type == "cath") {
    tested <- tested %>% filter(first_test == "cath")
  }

  tested %>%
    group_by(tile_stent_or_cabg_005_tested) %>%
    summarize(
      n = n(),
      avg_cost = mean(cost, na.rm = TRUE),
      avg_daly_250 = mean(daly_250, na.rm = TRUE),
      cost_per_daly_100 = pooled_ratio(cost, daly_100),
      cost_per_daly_200 = pooled_ratio(cost, daly_200),
      cost_per_daly_250 = pooled_ratio(cost, daly_250),
      cost_per_daly_300 = pooled_ratio(cost, daly_300),
      cost_per_daly_500 = pooled_ratio(cost, daly_500)
    ) %>%
    ungroup()
}

# Overall pooled cost per DALY (daly_250) for tested encounters, floored.
#
# Uses the same row selection as get_avg_cost(): the global `cohort`
# filtered to `test_010_day`, and to `first_test == "cath"` when
# `test_type == "cath"`.
#
# Args:
#   test_type: "cath" restricts to catheterization; anything else keeps all.
#   ...: extra config columns passed by pmap(); ignored.
#
# Returns:
#   floor(sum(cost) / sum(daly_250)), with NAs dropped from each sum.
get_avg_overall <- function(test_type = "all", ...) {
  tested <- copy(cohort) %>% filter(test_010_day)
  if (test_type == "cath") {
    tested <- tested %>% filter(first_test == "cath")
  }

  total_cost <- sum(tested$cost, na.rm = TRUE)
  total_daly <- sum(tested$daly_250, na.rm = TRUE)
  floor(total_cost / total_daly)
}

# Build the cost-effectiveness curve for one test population.
#
# Draws the following:
#   * cost_per_daly_250 by predicted-risk tile (dotted line plus points);
#   * a shaded band between the cost_per_daly_{interval_lo} and
#     cost_per_daly_{interval_hi} columns;
#   * a dashed reference line at the $150,000 threshold;
#   * dashed, labelled benchmark lines for other interventions.
#
# Args:
#   avg_cost_daly: per-tile summary as returned by get_avg_cost(). It must
#     contain tile_stent_or_cabg_005_tested, cost_per_daly_250 and the two
#     interval columns.
#   avg_cost_overall: currently unused; accepted so pmap() can pass the
#     column built from get_avg_overall().
#   test_type: "all", "cath" or "stress"; only used for the optional title.
#   interval_lo, interval_hi: suffixes selecting the band's columns.
#   ...: other config columns passed by pmap(); ignored.
#   show_title: add a title naming the test population. The default FALSE
#     reproduces the historical, untitled output.
#
# Returns:
#   A ggplot object.
plot_avg_cost_daly <- function(
    avg_cost_daly, avg_cost_overall, test_type = "all", interval_lo = 300,
    interval_hi = 200, ..., show_title = FALSE) {
  # Label y-axis breaks in thousands of dollars (e.g. 1e5 -> "$100").
  daly_scale <- function(x) {
    case_when(
      x == 70000 ~ "$70",
      x == 100000 ~ "$100",
      x == 170000 ~ "$170",
      x == 200000 ~ "$200",
      x == 300000 ~ "$300",
      TRUE ~ scales::dollar(x / 1000)
    )
  }
  gradient <- c("#7EB59F", "#BCCD63", "#EED979", "#FCBB6D", "#D49478")

  interval_vars <- c(
    glue("cost_per_daly_{interval_lo}"),
    glue("cost_per_daly_{interval_hi}")
  )

  cutoff_costeff_lab <- "Cost-effectiveness threshold: $150,000   "

  # copy() first: setnames() renames by reference and would otherwise
  # modify the caller's data frame.
  plot_df <- copy(avg_cost_daly) %>%
    setnames(interval_vars, c("costeff_lo", "costeff_hi"))

  gg_costcurve <- ggplot(
    plot_df,
    aes(
      x = tile_stent_or_cabg_005_tested,
      y = cost_per_daly_250
    )
  ) +
    labs(
      x = "Percentile of Predicted Risk",
      y = "Cost per Life-Year (in Thousands)"
    ) +
    scale_y_continuous(
      labels = daly_scale,
      breaks = c(seq(1e5, 14e5, by = 1e5))
    ) +
    # `color = NA` is a fixed setting, not a mapping. Inside aes() it would
    # be run through scale_color_manual() and drawn as the NA colour.
    geom_ribbon(
      aes(ymin = costeff_lo, ymax = costeff_hi),
      color = NA, alpha = 0.4, fill = a$disc_palette[1]
    ) +
    geom_line(linetype = "dotted", color = a$disc_palette[1]) +
    geom_point(color = a$disc_palette[1]) +
    theme_bw() +
    scale_x_continuous(breaks = 1:5, labels = (1:5) * 20) +
    theme(text = element_text(family = "Optima", size = 50, lineheight = 0.25)) +
    geom_hline(yintercept = 150000, alpha = 1, color = "black", linetype = "dashed") +
    annotate("label", x = 2, y = 150000, label = cutoff_costeff_lab, hjust = 0, alpha = 1, family = "Optima", size = 16) +
    geom_hline(yintercept = 70000, alpha = 0.5, color = gradient[1], linetype = "dashed") +
    geom_hline(yintercept = 100000, alpha = 0.5, color = gradient[2], linetype = "dashed") +
    geom_hline(yintercept = 170000, alpha = 0.5, color = gradient[3], linetype = "dashed") +
    geom_hline(yintercept = 200000, alpha = 0.5, color = gradient[4], linetype = "dashed") +
    geom_hline(yintercept = 300000, alpha = 0.5, color = gradient[5], linetype = "dashed") +
    annotate("label", x = 1, y = 70000, label = "Dialysis  ", hjust = 0, size = 8, family = "Optima", fill = gradient[1]) +
    annotate("label", x = 1, y = 100000, label = "XR Lung Cancer Screen  ", hjust = 0, size = 8, family = "Optima", fill = gradient[2]) +
    annotate("label", x = 1, y = 170000, label = "Cancer Immunotherapy  ", hjust = 0, size = 8, family = "Optima", fill = gradient[3]) +
    annotate("label", x = 1, y = 200000, label = "Mammography, 40-50yo  ", hjust = 0, size = 8, family = "Optima", fill = gradient[4]) +
    annotate("label", x = 1, y = 300000, label = "Biologics for Rare Diseases  ", hjust = 0, size = 8, family = "Optima", fill = gradient[5]) +
    scale_color_manual(values = a$disc_palette) +
    scale_fill_manual(values = a$disc_palette) +
    coord_cartesian(ylim = c(0, 14e5))

  if (isTRUE(show_title)) {
    test_lab <- case_when(
      test_type == "all" ~ "Testing",
      test_type == "cath" ~ "Catheterization",
      test_type == "stress" ~ "Stress Testing"
    )
    gg_costcurve <- gg_costcurve +
      labs(title = glue("Cost Effectiveness of {test_lab} by Predicted Risk"))
  }

  gg_costcurve
}
# Output path for one cost-curve PNG inside the global `temp` directory.
#
# Args:
#   test_type: "all" adds no suffix; any other value adds "_<test_type>".
#   interval_lo, interval_hi: band column suffixes, embedded in the name.
#   ...: other config columns passed by pmap(); ignored.
#
# Returns:
#   A single path, e.g. file.path(temp, "COST_CURVE_cath_500_100.png").
get_filename <- function(test_type = "all", interval_lo = 300,
                         interval_hi = 200, ...) {
  # Scalar choice: plain if/else rather than vectorised ifelse().
  type_suffix <- if (test_type == "all") "" else paste0("_", test_type)
  file.path(
    temp,
    paste0("COST_CURVE", type_suffix, "_", interval_lo, "_", interval_hi, ".png")
  )
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))
# Presumably interpolated as {overnight_lab} into the test_cohort path
# template by glue() below; confirm against filepaths.yml. isTRUE() keeps
# this a length-1 string even if --overnight was not supplied (NULL);
# ifelse() would return logical(0) in that case.
overnight_lab <- if (isTRUE(arg_list$overnight)) "_overnight" else ""
daly_cost <- readRDS(paths$analysis$daly_cost) %>%
  select(ed_enc_id, cost, daly_100, daly_200, daly_250, daly_300, daly_500)
# filter(!exclude) also drops rows where `exclude` is NA.
cohort <- readRDS(glue(paths$analysis$test_cohort)) %>%
  u$safe_left_join(daly_cost) %>%
  filter(!exclude)

# Config -----------------------------------------------------------------------
message("Configuring plotting parameters...")
# One row per (test population, band) combination. The (interval_lo,
# interval_hi) pairs pick the cost_per_daly_* columns that bound the shaded
# band in each plot. Rows are listed in the same order crossing() + filter()
# produced.
config <- expand_grid(
  test_type = c("all", "cath"),
  tibble(
    interval_lo = c(300, 500),
    interval_hi = c(200, 100)
  )
)

# Prep Data --------------------------------------------------------------------
message("Plotting cost-effectiveness curves...")
# pmap() passes every existing column of the data frame as a named argument.
# Each helper picks the columns it needs and absorbs the rest via `...`.
plots_df <- config %>%
  mutate(avg_cost_daly = pmap(., get_avg_cost)) %>%
  mutate(avg_cost_overall = pmap_dbl(., get_avg_overall)) %>%
  mutate(plot = pmap(., plot_avg_cost_daly)) %>%
  mutate(filename = pmap_chr(., get_filename))

message("Saving...")
# ggsave() does not create missing directories by default, and the overnight
# subfolder may not exist yet.
dir.create(temp, recursive = TRUE, showWarnings = FALSE)
# pwalk(): called for its side effect only, so nothing is auto-printed.
# `units` is spelled out rather than relying on partial matching of `unit`.
plots_df %>%
  select(plot, filename) %>%
  pwalk(ggsave, width = 10, height = 7, units = "in")

message("Done.")
